# -*- coding: utf-8 -*-
"""
Created on Tue Apr 23 15:42:17 2019

@author: jlani
"""

import numpy as np
import itertools

class PermTester():
    '''Permutation test of rationality based on the Afriat Efficiency Index.

    Takes consumption data C and price data P. Let N be the number of
    observations and L the number of goods; C and P are N x L matrices.

    The e-score of the actual data is located by bisection on [0, 1]: the
    returned e is one at which the actual data passes, within ``e_tol`` of
    the pass/fail boundary. The data is then shuffled repeatedly and the
    fraction of shuffles that also pass at that e is returned.

    To run the test call an instance of PermTester with two arguments.
    1. test_name: 'garp' for vanilla GARP or 'qlrp' for the quasilinear
       version.
    2. mode: 'c', 'b' or 'r' to shuffle consumption bundles, budget shares,
       or consumption rays respectively.

    The tuning constants below are class attributes and may be overridden on
    an instance.
    '''

    #Bisection for the e-score stops once the window is narrower than this.
    e_tol = 2.0 ** -15
    #With fewer observations than this, every permutation is enumerated;
    #otherwise n_random_perms random permutations are drawn.
    exhaustive_below = 8
    n_random_perms = 10000
    #After early_stop_after permutations, stop if the pass rate so far
    #already exceeds early_stop_ratio.
    early_stop_after = 1000
    early_stop_ratio = 0.2

    def __init__(self, C, P):
        '''C and P are N x L matrices where C[n,l] is the amount of good l
        consumed in period n and P[n,l] is the price of good l in period n.'''

        self.C = np.asarray(C)
        self.P = np.asarray(P)
        self.N = self.C.shape[0]

        #Value matrix (N x N): V[n,m] is the value of the bundle of period n
        #evaluated at the price vector of period m.
        self.V = self.C @ self.P.T

        #Budget-share value tensor (N x N x N), only needed for mode 'b'.
        #It is expensive (N^3 entries), so __call__ builds it lazily the first
        #time mode 'b' is requested; see _budget_share_values.
        self.VB = None

    def __call__(self, test_name, mode):
        '''Runs the permutation test and returns the fraction of permutations
        whose data passes the test at the actual data's e-score.

        test_name is 'garp' or 'qlrp'. mode is 'c', 'b' or 'r', indicating
        that shuffling takes place over consumption bundles, budget shares,
        or budget rays.

        Raises ValueError (an Exception subclass) for an unknown mode or
        test_name.'''

        if mode not in ('c', 'b', 'r'):
            raise ValueError('mode has taken a forbidden value.')
        self.mode = mode
        if mode == 'b' and self.VB is None:
            self.VB = self._budget_share_values()

        #The e-score is computed on the unpermuted data.
        self.perm = list(range(self.N))
        e = self._get_e(test_name)

        #Ties (permutations exactly matching the data's e-score) are not
        #tracked separately, for speed.
        total = 0   #Number of permutations run.
        passes = 0  #Permutations doing weakly better than the actual data.
        for perm in self._permutor():
            self.perm = perm
            #True if the permuted data is at least as good as the real data.
            if self._run_test(test_name, e):
                passes += 1
            total += 1
            if (total == self.early_stop_after
                    and passes / total > self.early_stop_ratio):
                break
        return passes / total

    def _budget_share_values(self):
        '''Builds the N x N x N budget-share value tensor VB.

        VB[t,s,r] is the value of the consumption bundle formed by giving the
        budget shares of period r to the budget set of period t and valuing
        the bundle at the prices of period s:

            VB[t,s,r] = ( (p^t c^t) / (p^r c^r) ) * sum_l p_l^s p_l^r c_l^r / p_l^t

        so VB[t,s,perm[t]] is the value needed for revealed preference after a
        budget-share shuffle. Divides by every price and by each own
        expenditure p^r c^r, so these must be nonzero.
        '''
        C, P = self.C, self.P
        own = np.diag(self.V)  #own[t] = p^t c^t
        #shares[t,s,r] = sum_l P[s,l] * P[r,l] * C[r,l] / P[t,l]
        shares = np.einsum('sl,rl,tl->tsr', P, P * C, 1.0 / P)
        return (own[:, None, None] / own[None, None, :]) * shares

    #=========================================================================
    #TEST SELECTORS
    #=========================================================================
    def _get_e(self, test_name):
        '''Returns the e-score for test_name ('garp' or 'qlrp') under the
        current self.perm. Raises ValueError for any other name.'''

        if test_name == 'garp':
            return self._garp_find_e()
        if test_name == 'qlrp':
            return self._qlrp_find_e()
        raise ValueError('{!r} is an invalid test name.'.format(test_name))

    def _run_test(self, test_name, e):
        '''Runs the appropriate test ('garp' or 'qlrp') at efficiency e and
        returns True if it passes. Raises ValueError for any other name.'''

        if test_name == 'garp':
            return self._garp(e)
        if test_name == 'qlrp':
            return self._qlrp(e)
        raise ValueError('{!r} is an invalid test name.'.format(test_name))

    #=========================================================================
    #DIRECT RP METHODS
    #=========================================================================
    def _rp(self, n, m, e):
        '''Returns True if the bundle of ob n is directly e-rp (strictly) to
        the bundle of ob m after the shuffle, otherwise False.

        mode 'c': ob k holds bundle c^{perm[k]} facing prices p^k.
        mode 'b': values come from the budget-share tensor VB.
        mode 'r': ob m holds bundle c^{perm[m]} rescaled to cost p^m c^m at
                  prices p^m.'''
        perm = self.perm; V = self.V; mode = self.mode

        if mode == 'c':
            return e * V[perm[n], n] > V[perm[m], n]
        if mode == 'b':
            #Note VB[n,n,perm[n]] = VB[n,n,n]
            return e * self.VB[n, n, perm[n]] > self.VB[m, n, perm[m]]
        if mode == 'r':
            return e * V[n, n] > (V[perm[m], n] * V[m, m]) / V[perm[m], m]
        raise ValueError('mode has taken a forbidden value.')

    def _rp_weight(self, n, m, e):
        '''Returns e * P[n] * C[n] - P[n] * C[m] after the shuffle, with the
        same per-mode meaning of the bundles as in _rp.'''
        perm = self.perm; V = self.V; mode = self.mode

        if mode == 'c':
            return e * V[perm[n], n] - V[perm[m], n]
        if mode == 'b':
            #Note VB[n,n,perm[n]] = VB[n,n,n]
            return e * self.VB[n, n, perm[n]] - self.VB[m, n, perm[m]]
        if mode == 'r':
            return e * V[n, n] - (V[perm[m], n] * V[m, m]) / V[perm[m], m]
        raise ValueError('mode has taken a forbidden value.')

    #=========================================================================
    #GARP METHODS
    #=========================================================================
    def _garp(self, e):
        '''Returns True if the current (shuffled) data has no cycle in the
        strict direct e-rp graph, otherwise False.'''

        N = self.N
        #untouched[n] is True if we have not begun working on ob n.
        untouched = N * [True]
        #working[n] is True while ob n is on the current search path.
        working = N * [False]

        for start in range(N):
            if untouched[start] and self._dfs(start, untouched, working, e):
                return False  #found a cycle
        return True

    def _dfs(self, current, untouched, working, e):
        '''Depth-first search of the e-rp graph from ob current.

        Updates untouched/working in place and returns True if a cycle is
        found, otherwise False. Uses an explicit stack instead of recursion
        so large N cannot exceed Python's recursion limit; children are
        examined in the same order (0..N-1) as a recursive search.'''

        N = self.N
        rp = self._rp

        untouched[current] = False
        working[current] = True
        #Each entry: (ob, iterator over the candidate children not yet examined).
        stack = [(current, iter(range(N)))]
        while stack:
            node, candidates = stack[-1]
            for child in candidates:
                if not rp(node, child, e):
                    continue
                if untouched[child]:
                    #Descend; node's iterator resumes after child is finished.
                    untouched[child] = False
                    working[child] = True
                    stack.append((child, iter(range(N))))
                    break
                if working[child]:
                    #child is on the current path: node -> child closes a cycle.
                    return True
            else:
                #Every child of node examined without finding a cycle.
                working[node] = False
                stack.pop()
        return False

    def _garp_find_e(self):
        '''Finds the threshold e at which the data just passes garp.'''
        return self._find_e(self._garp)

    #=========================================================================
    #QLRP METHODS
    #=========================================================================
    def _qlrp(self, e):
        '''Returns True if the current (shuffled) data satisfies e-qlrp, i.e.
        the weighted graph W has no positive-weight cycle.'''

        N = self.N
        #W[j,k] denotes the weight of the edge from k to j.
        W = np.array([[self._rp_weight(j, k, e) for k in range(N)]
                      for j in range(N)], dtype=float).reshape(N, N)

        #Max-plus Floyd-Warshall. While no diagonal entry is positive,
        #W[i,i] <= 0, so row i and column i cannot change during step i and
        #updating the whole matrix at once matches the in-place update.
        for i in range(N):
            W = np.maximum(W, W[:, i:i + 1] + W[i:i + 1, :])
            if (np.diagonal(W) > 0.).any():
                return False  #Found a positive cycle.
        return True

    def _qlrp_find_e(self):
        '''Finds the threshold e at which the data just passes qlrp.'''
        return self._find_e(self._qlrp)

    def _find_e(self, passes):
        '''Bisects on [0, 1] for the e at which passes(e) switches to False.

        Returns the lower end of the final window (width < e_tol). Every
        value assigned to it was tested and passed, except the initial 0.0,
        which is assumed to pass (with e = 0 no strict e-rp relation holds
        for nonnegative values).'''

        high_e = 1.0
        low_e = 0.0
        e = 1.0
        while high_e - low_e >= self.e_tol:
            if passes(e):
                low_e = e   #e passes, so raise the lower bound.
            else:
                high_e = e  #e fails, so lower the upper bound.
            e = (high_e + low_e) / 2.0
        #Return low_e rather than the untested midpoint e, so the data
        #itself is guaranteed to pass at the returned e-score.
        return low_e

    #=========================================================================
    #PERMUTATION METHODS
    #=========================================================================
    def _permutor(self):
        '''Generator of permutations for the test.

        If N >= exhaustive_below (default 8), yields n_random_perms (default
        10,000) permutations from np.random.permutation. Otherwise yields
        every permutation of range(N), the identity included.'''
        if self.N >= self.exhaustive_below:
            for _ in range(self.n_random_perms):
                yield np.random.permutation(self.N)
        else:
            yield from itertools.permutations(range(self.N))
    
    
    

    



